diff --git a/examples/audio/transformer_asr.py b/examples/audio/transformer_asr.py
index 8cd3e04..5edf885 100644
--- a/examples/audio/transformer_asr.py
+++ b/examples/audio/transformer_asr.py
@@ -35,6 +35,7 @@ from glob import glob
 import tensorflow as tf
 from tensorflow import keras
 from tensorflow.keras import layers
+from sky_callback import SkyKerasCallback
 
 
 """
@@ -520,7 +521,7 @@ learning_rate = CustomSchedule(
 optimizer = keras.optimizers.Adam(learning_rate)
 model.compile(optimizer=optimizer, loss=loss_fn)
 
-history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
+history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb, SkyKerasCallback()], epochs=1)
 
 """
 In practice, you should train for around 100 epochs or more.
